{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deeply Embedded DSLs in Python\n",
    "\n",
    "Borrowing terminology from [\"Deep and Shallow Embeddings\"](https://www.cs.ox.ac.uk/people/jeremy.gibbons/publications/embedding-short.pdf) by Jeremy Gibbons, domain-specific languages that are embedded in a host language come in two flavors: shallow embeddings and deep embeddings.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Shallow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataclasses\n",
    "import typing\n",
    "\n",
    "@dataclasses.dataclass\n",
    "class NaturalShallow:\n",
    "    value: int\n",
    "\n",
    "    def __add__(self, other: \"NaturalShallow\") -> \"NaturalShallow\":\n",
    "        return NaturalShallow(self.value + other.value)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NaturalShallow(value=9)"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "NaturalShallow(3) + NaturalShallow(4) + NaturalShallow(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Deep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataclasses\n",
    "import typing\n",
    "\n",
    "\n",
    "@dataclasses.dataclass\n",
    "class NaturalAdd:\n",
    "    # Expression node: records an addition without performing it.\n",
    "    left: \"NaturalDeep\"\n",
    "    right: \"NaturalDeep\"\n",
    "\n",
    "    def execute(self) -> int:\n",
    "        return self.left.execute() + self.right.execute()\n",
    "\n",
    "\n",
    "@dataclasses.dataclass\n",
    "class NaturalDeep:\n",
    "    # Either a literal integer or a deferred addition.\n",
    "    value: typing.Union[int, NaturalAdd]\n",
    "\n",
    "    def __add__(self, other: \"NaturalDeep\") -> \"NaturalDeep\":\n",
    "        # Build a new expression node instead of computing the sum.\n",
    "        return NaturalDeep(NaturalAdd(self, other))\n",
    "\n",
    "    def execute(self) -> int:\n",
    "        if isinstance(self.value, int):\n",
    "            return self.value\n",
    "        return self.value.execute()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NaturalDeep(value=NaturalAdd(left=NaturalDeep(value=NaturalAdd(left=NaturalDeep(value=3), right=NaturalDeep(value=4))), right=NaturalDeep(value=1)))"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "NaturalDeep(3) + NaturalDeep(4) + NaturalDeep(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(NaturalDeep(3) + NaturalDeep(4) + NaturalDeep(1)).execute()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the shallow embedding performs the computation as soon as the function is called. This is how most APIs in Python work, including NumPy's. It is nice and simple to understand.\n",
    "\n",
    "But what do we get from a deep embedding? One thing we get is the ability to define different interpretations of the same expression. Here we have defined `execute`, but we also get (for free, from the dataclass `repr`) a string representation of the computation. That alone is useful for understanding what the API will do.\n",
    "\n",
    "But the real power is that we turn the question from \"What do I do when I call this function?\" into \"I have this data structure representing some call. Now what can I do with it?\" This is the same sort of power Lisp users talk about: we can add macro-like behavior without any changes to the language.\n",
    "\n",
    "For example, we could now go into the expression data structures and rearrange them as we wish, with all the power of Python itself."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A library like multipledispatch could be useful in creating a shallowly embedded DSL, but it won't get you all the way to a deeply embedded DSL, because it still ties the execution of a function to its call. **The whole point of a deeply embedded DSL is to separate two steps: the user saying \"I want to call this function with these arguments\" and the actual execution of that function with those arguments.** By putting some abstraction between those steps, we can do a lot of interesting things, like optimize the expression, translate it to some other computation backend, or try to prove properties about it."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What we need"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Libraries like `ibis` or `dask.array` use this strategy effectively today to take one API (Pandas, NumPy) and map it onto a different execution context (SQL, a computation graph).\n",
    "\n",
    "On the other end, efforts to expose NumPy's API are about enabling more libraries like this to exist, by allowing libraries to override how NumPy functions are called.\n",
    "\n",
    "What is missing? It is hard for these packages to work with each other. They exist in separate worlds and can't share code. Numba, for example, has some useful code for optimizing certain NumPy calls. There is no reason it couldn't be reused when you create a NumPy expression in `dask.array`. Why isn't it shared? Because each library uses a different representation of what a NumPy expression is.\n",
    "\n",
    "So I am curious what a shared base could look like that supports these different use cases and encourages best practices, innovation, and sharing of work.\n",
    "\n",
    "In the NumPy case, this would mean first being able to turn a number of NumPy calls into a deeply embedded representation that is agnostic to how it will be executed. Then there would be libraries that convert between this representation and Numba IR or Dask expressions. So what is at the core?\n",
    "\n",
    "We need a way to express a call to a function without actually calling it. So let's create that:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclasses.dataclass\n",
    "class Call:\n",
    "    function: typing.Callable\n",
    "    args: typing.Tuple[\"NaturalDeep2\", ...]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's rewrite our natural number implementation using it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclasses.dataclass\n",
    "class NaturalDeep2:\n",
    "    value: typing.Union[int, Call]\n",
    "\n",
    "    def __add__(self, other: \"NaturalDeep2\") -> \"NaturalDeep2\":\n",
    "        # Record the call to `__add__` rather than performing the addition.\n",
    "        return NaturalDeep2(Call(NaturalDeep2.__add__, (self, other)))\n",
    "\n",
    "    def execute(self) -> int:\n",
    "        if isinstance(self.value, int):\n",
    "            return self.value\n",
    "        # The only function we ever record is `__add__`, so a call is a sum.\n",
    "        left, right = self.value.args\n",
    "        return left.execute() + right.execute()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NaturalDeep2(value=Call(function=<function NaturalDeep2.__add__ at 0x10f01e840>, args=(NaturalDeep2(value=Call(function=<function NaturalDeep2.__add__ at 0x10f01e840>, args=(NaturalDeep2(value=123), NaturalDeep2(value=456)))), NaturalDeep2(value=123))))"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "NaturalDeep2(123) + NaturalDeep2(456) + NaturalDeep2(123) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "702"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(NaturalDeep2(123) + NaturalDeep2(456) + NaturalDeep2(123)).execute()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This looks good. However, let's separate the execution from the class, so we can add multiple interpretations without modifying the class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclasses.dataclass\n",
    "class NaturalDeep3:\n",
    "    value: typing.Union[int, Call]\n",
    "\n",
    "    def __add__(self, other: \"NaturalDeep3\") -> \"NaturalDeep3\":\n",
    "        # Build a NaturalDeep3 (not a NaturalDeep2) so the tree stays one type.\n",
    "        return NaturalDeep3(Call(NaturalDeep3.__add__, (self, other)))\n",
    "\n",
    "\n",
    "def execute(natural: NaturalDeep3) -> int:\n",
    "    # Interpreter defined outside the class: leaves are ints, calls are sums.\n",
    "    if isinstance(natural.value, int):\n",
    "        return natural.value\n",
    "    return execute(natural.value.args[0]) + execute(natural.value.args[1])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "execute(NaturalDeep3(1) + NaturalDeep3(2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can add another function, such as `cost`, that computes the expected cost of the computation before it executes. Here the cost is simply the number of additions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cost(natural: NaturalDeep3) -> int:\n",
    "    if isinstance(natural.value, int):\n",
    "        return 0\n",
    "    return 1 + cost(natural.value.args[0]) + cost(natural.value.args[1])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cost(NaturalDeep3(1) + NaturalDeep3(2) + NaturalDeep3(1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You might notice that `cost` and `execute` share a similar structure: each has a base case for the leaf node and an accumulator over the child nodes. We can factor out this pattern with a general `fold` higher-order function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "T = typing.TypeVar(\"T\")\n",
    "\n",
    "def fold_natural(\n",
    "    leaf_fn: typing.Callable[[int], T],\n",
    "    call_fn: typing.Callable[[typing.Callable, typing.Iterable[T]], T],\n",
    "    natural: NaturalDeep3,\n",
    ") -> T:\n",
    "    if isinstance(natural.value, int):\n",
    "        return leaf_fn(natural.value)\n",
    "\n",
    "    # Fold each argument recursively, then combine the results with call_fn.\n",
    "    partial_fold = lambda other_natural: fold_natural(leaf_fn, call_fn, other_natural)\n",
    "    return call_fn(natural.value.function, map(partial_fold, natural.value.args))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can define evaluation and cost using this higher-order function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "\n",
    "cost2 = functools.partial(fold_natural, lambda _: 0, lambda fn, costs: 1 + sum(costs))\n",
    "evaluate2 = functools.partial(fold_natural, lambda v: v, lambda fn, values: sum(values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cost2(NaturalDeep3(1) + NaturalDeep3(2) + NaturalDeep3(1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "12"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluate2(NaturalDeep3(10) + NaturalDeep3(2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great! Now we are really making some headway in abstracting out the execution.\n",
    "\n",
    "To do this for real, we would have to give `evaluate` some form of registration and multiple dispatch. We could do this in many different ways: with pattern matching, replacement, or a more traditional approach of defining a separate handler function for each operation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "There is another way we could define DSLs, which is how Numba and AutoGraph work. They don't build up expressions; they analyze the control flow graph (CFG). Why? Because not all operations in Python are expressions.\n",
    "\n",
    "For example, with the approach here you can override function application, but you cannot override control flow. There would be no way to implement something like this:\n",
    "\n",
    "```python\n",
    "x: NaturalDeep3\n",
    "\n",
    "if x < NaturalDeep3(10):\n",
    "    y = x - 100\n",
    "else:\n",
    "    y = x + 30\n",
    "```\n",
    "\n",
    "We could do it by introducing a special boolean type and an if expression, like this:\n",
    "\n",
    "```python\n",
    "y = (x < NaturalDeep3(10)).if_(x - 100, x + 30)\n",
    "```\n",
    "\n",
    "But in general, translating between the first form and the second is pretty complicated!\n",
    "\n",
    "What if there is a return statement in the if clause? How would you compute a proper expression for this?\n",
    "\n",
    "This is the problem that the [\"The 800 Pound Python in the Machine Learning Room\"](https://paperpile.com/app/p/f543bd75-2192-0a48-ae1d-c3d28c410439) paper tackles. It provides a function wrapper that transforms the AST to replace control flow like that with function calls, so that you can override it.\n",
    "\n",
    "Generally, this problem comes down to translating between a control flow graph and an expression tree. This is where the [\"Optimizing compilation with the Value State Dependence Graph\"](https://paperpile.com/app/p/da3d91a8-5c51-0d76-ba27-847c8a7eb41a) paper comes into play. Our representation here in Python is a subset of the \"Value State Dependence Graph\" representation.\n",
    "\n",
    "So what does this mean for a generic DSL library in Python? It's much easier to think about execution on an expression tree/graph, like the one we introduced. Pattern matching can be used to implement partial evaluation, which can then drive the computation.\n",
    "\n",
    "But it's nice to write imperative code, and control-flow-heavy code is the common idiom in Python. You might notice, though, that in the NumPy world it is common to use NumPy's own abstractions for control flow rather than Python's native control flow.\n",
    "\n",
    "So as a start, we can just support that kind of API and forget about starting from a CFG.\n",
    "\n",
    "Eventually, however, we need to be able to start with, and end up with, a CFG. If we want to compile one of these expressions back to Python, or to another language that isn't based on expression trees, then we will need one.\n",
    "\n",
    "So to be generally useful, we should figure out the right abstractions for translating these kinds of expression trees to and from CFGs. This is not new work; it has been done in the compiler community. What it takes is understanding that work well enough to know which abstractions to bring over to Python."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
